{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "7b4aef48-3c3f-4bec-8906-2844d31556ba",
        "showInput": false
      },
      "source": [
        "## Analytic and MC-based Expected Improvement (EI) acquisition\n",
        "\n",
        "In this tutorial, we compare the analytic and MC-based EI acquisition functions and show both `scipy`- and `torch`-based optimizers for optimizing the acquisition function. The tutorial highlights the modularity of BoTorch: different acquisition functions and accompanying optimization algorithms can easily be tried on the same fitted model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install dependencies if we are running in Colab\n",
        "import sys\n",
        "if 'google.colab' in sys.modules:\n",
        "    %pip install botorch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "6d70e250-da2b-460a-8989-5c048f7a6ce8",
        "showInput": false
      },
      "source": [
        "### Comparison of analytic and MC-based EI\n",
        "Note that we use the analytic and MC variants of the LogEI family of acquisition functions, which remedy numerical issues encountered in the naive implementations. See https://arxiv.org/pdf/2310.20708 for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649205799,
        "executionStopTime": 1668649205822,
        "originalKey": "f678d607-be4c-4f37-aed5-3597158432ce",
        "output": {
          "id": 8143993305683446,
          "loadingStatus": "loaded"
        },
        "requestMsgId": "0aae9d3f-d796-4a18-a4aa-b015b5b582ac"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "from botorch.fit import fit_gpytorch_mll\n",
        "from botorch.models import SingleTaskGP\n",
        "from botorch.test_functions import Hartmann\n",
        "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
        "\n",
        "neg_hartmann6 = Hartmann(dim=6, negate=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "332a3aeb-187f-4265-b52f-4ca7bdcf5e4a",
        "showInput": false
      },
      "source": [
        "First, we generate some random data and fit a `SingleTaskGP` to the 6-dimensional synthetic test function Hartmann6."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649205895,
        "executionStopTime": 1668649206067,
        "originalKey": "a7724f86-8b67-4f70-bf57-f0da79b88f52",
        "output": {
          "id": 1605553740344114,
          "loadingStatus": "loaded"
        },
        "requestMsgId": "25794582-0506-4e89-a112-ba362b7c7e59"
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(seed=12345)  # to keep the data conditions the same\n",
        "dtype = torch.float64\n",
        "train_x = torch.rand(10, 6, dtype=dtype)\n",
        "train_obj = neg_hartmann6(train_x).unsqueeze(-1)\n",
        "model = SingleTaskGP(train_X=train_x, train_Y=train_obj)\n",
        "mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
        "fit_gpytorch_mll(mll);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "576d19e8-2164-4ae9-8416-4ca205d35a77",
        "showInput": false
      },
      "source": [
        "Initialize an analytic LogEI acquisition function on the fitted model, using the best observed training value as the incumbent `best_f`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649206124,
        "executionStopTime": 1668649206138,
        "originalKey": "1d611db5-29ee-4e9b-8fea-fe9d274b5222",
        "requestMsgId": "ba6f0917-f622-486a-818c-bb9ba4580ea5"
      },
      "outputs": [],
      "source": [
        "from botorch.acquisition.analytic import LogExpectedImprovement\n",
        "\n",
        "best_value = train_obj.max()\n",
        "LogEI = LogExpectedImprovement(model=model, best_f=best_value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "e86567ac-e3f1-46ec-889a-6cc21f2fc918",
        "showInput": false
      },
      "source": [
        "Next, we optimize the analytic EI acquisition function using 20 random restarts chosen from 100 initial raw samples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649206218,
        "executionStopTime": 1668649206938,
        "originalKey": "dc5613c6-2f99-4193-8956-6e710fee5fa2",
        "output": {
          "id": 422599616946465,
          "loadingStatus": "loaded"
        },
        "requestMsgId": "3df2fc12-7f4c-4abb-b1d2-90bb3b8bf05c"
      },
      "outputs": [],
      "source": [
        "from botorch.optim import optimize_acqf\n",
        "\n",
        "torch.manual_seed(seed=0)  # to keep the restart conditions the same\n",
        "new_point_analytic, _ = optimize_acqf(\n",
        "    acq_function=LogEI,\n",
        "    bounds=torch.tensor([[0.0] * 6, [1.0] * 6], dtype=dtype),\n",
        "    q=1,\n",
        "    num_restarts=20,\n",
        "    raw_samples=100,\n",
        "    options={},\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649206992,
        "executionStopTime": 1668649207011,
        "originalKey": "76fb19a3-c2c2-451a-8c0b-50cb14c55460",
        "requestMsgId": "a5cbada9-0b7c-41a2-934f-10d9bbe2e316"
      },
      "outputs": [],
      "source": [
        "# NOTE: The acquisition value here is the log of the expected improvement.\n",
        "LogEI(new_point_analytic), new_point_analytic"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "99862767-916b-4de0-881a-2eafc819f356",
        "showInput": false
      },
      "source": [
        "Now, let's swap out the analytic acquisition function and replace it with an MC version. Note that we are in the `q = 1` case; for `q > 1`, an analytic version does not exist."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649207083,
        "executionStopTime": 1668649207929,
        "originalKey": "aaf04cba-3716-4fbd-8baa-2c75dd068860",
        "output": {
          "id": 495747073400348,
          "loadingStatus": "loaded"
        },
        "requestMsgId": "0e7691f2-34c7-43df-a247-7f7ba95220f1"
      },
      "outputs": [],
      "source": [
        "from botorch.acquisition.logei import qLogExpectedImprovement\n",
        "from botorch.sampling import SobolQMCNormalSampler\n",
        "\n",
        "\n",
        "sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]), seed=0)\n",
        "MC_LogEI = qLogExpectedImprovement(model, best_f=best_value, sampler=sampler, fat=False)\n",
        "torch.manual_seed(seed=0)  # to keep the restart conditions the same\n",
        "new_point_mc, _ = optimize_acqf(\n",
        "    acq_function=MC_LogEI,\n",
        "    bounds=torch.tensor([[0.0] * 6, [1.0] * 6], dtype=dtype),\n",
        "    q=1,\n",
        "    num_restarts=20,\n",
        "    raw_samples=100,\n",
        "    options={},\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649207976,
        "executionStopTime": 1668649207989,
        "originalKey": "73ffa9ea-3cff-46eb-91ea-b2f75fdb07f2",
        "requestMsgId": "b780cff4-6e90-4e39-8558-b04136e71e94"
      },
      "outputs": [],
      "source": [
        "# NOTE: The acquisition value here is the log of the expected improvement.\n",
        "MC_LogEI(new_point_mc), new_point_mc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "41f37b1a-df0a-4bb9-8996-fdddd3fe298c",
        "showInput": false
      },
      "source": [
        "Check that the two generated points are close."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208035,
        "executionStopTime": 1668649208043,
        "originalKey": "c5c20ba9-82af-4d07-832f-86ede74f8959",
        "requestMsgId": "0b3db1ad-6ddb-4f86-9767-e0a486914b33"
      },
      "outputs": [],
      "source": [
        "torch.linalg.norm(new_point_mc - new_point_analytic)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "f45c60db-1318-405d-b95a-6b2723bc5f46",
        "showInput": false
      },
      "source": [
        "### Using a torch optimizer on a stochastic acquisition function\n",
        "\n",
        "We can also optimize using a `torch` optimizer. This is particularly useful for a stochastic acquisition function, which we can obtain by using a `StochasticSampler`. First, we illustrate the usage of `torch.optim.Adam`. In the code snippet below, `gen_batch_initial_conditions` uses a heuristic to select a set of restart locations, `gen_candidates_torch` is a wrapper around the `torch` optimizer for maximizing the acquisition value, and `get_best_candidates` finds the best result among the random restarts.\n",
        "\n",
        "Under the hood, `gen_candidates_torch` uses a convergence criterion based on exponential moving averages of the loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208095,
        "executionStopTime": 1668649208252,
        "originalKey": "434e0156-94e1-4aca-a0f6-00f6ffb3a2b0",
        "requestMsgId": "3e15f8c7-ee79-4c95-8583-3bd376229a39"
      },
      "outputs": [],
      "source": [
        "from botorch.sampling.stochastic_samplers import StochasticSampler\n",
        "from botorch.generation import get_best_candidates, gen_candidates_torch\n",
        "from botorch.optim import gen_batch_initial_conditions\n",
        "\n",
        "resampler = StochasticSampler(sample_shape=torch.Size([512]))\n",
        "MC_LogEI_resample = qLogExpectedImprovement(model, best_f=best_value, sampler=resampler)\n",
        "bounds = torch.tensor([[0.0] * 6, [1.0] * 6], dtype=dtype)\n",
        "\n",
        "batch_initial_conditions = gen_batch_initial_conditions(\n",
        "    acq_function=MC_LogEI_resample,\n",
        "    bounds=bounds,\n",
        "    q=1,\n",
        "    num_restarts=20,\n",
        "    raw_samples=100,\n",
        ")\n",
        "batch_candidates, batch_acq_values = gen_candidates_torch(\n",
        "    initial_conditions=batch_initial_conditions,\n",
        "    acquisition_function=MC_LogEI_resample,\n",
        "    lower_bounds=bounds[0],\n",
        "    upper_bounds=bounds[1],\n",
        "    optimizer=torch.optim.Adam,\n",
        "    options={\"maxiter\": 500},\n",
        ")\n",
        "new_point_torch_Adam = get_best_candidates(\n",
        "    batch_candidates=batch_candidates, batch_values=batch_acq_values\n",
        ").detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208304,
        "executionStopTime": 1668649208320,
        "originalKey": "81c29b36-c663-47e1-8155-ad034c214f53",
        "requestMsgId": "aac6f703-e046-448a-8abe-1742befb9bf9"
      },
      "outputs": [],
      "source": [
        "# NOTE: The acquisition value here is the log of the expected improvement.\n",
        "MC_LogEI_resample(new_point_torch_Adam), new_point_torch_Adam"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208364,
        "executionStopTime": 1668649208372,
        "originalKey": "17fb0de0-3c5a-414e-9aba-b82710d166c0",
        "requestMsgId": "a13ce358-3ee6-43ad-9e3a-16181a8cdc1e"
      },
      "outputs": [],
      "source": [
        "torch.linalg.norm(new_point_torch_Adam - new_point_analytic)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "originalKey": "9cce59c6-18f5-4d61-ace0-a0d88cd4b8eb",
        "showInput": false
      },
      "source": [
        "By changing the `optimizer` argument of `gen_candidates_torch`, we can also try `torch.optim.SGD`. Note that without the adaptive step size selection of Adam, basic SGD does a worse job of optimizing unless the optimization parameters are tuned manually."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208422,
        "executionStopTime": 1668649208460,
        "originalKey": "dce34cee-fe31-476e-a20d-46237d5861e0",
        "requestMsgId": "47a5a0d7-3281-403d-910a-97054021b811"
      },
      "outputs": [],
      "source": [
        "batch_candidates, batch_acq_values = gen_candidates_torch(\n",
        "    initial_conditions=batch_initial_conditions,\n",
        "    acquisition_function=MC_LogEI_resample,\n",
        "    lower_bounds=bounds[0],\n",
        "    upper_bounds=bounds[1],\n",
        "    optimizer=torch.optim.SGD,\n",
        "    options={\"maxiter\": 500},\n",
        ")\n",
        "new_point_torch_SGD = get_best_candidates(\n",
        "    batch_candidates=batch_candidates, batch_values=batch_acq_values\n",
        ").detach()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208505,
        "executionStopTime": 1668649208523,
        "originalKey": "350e456d-0d1c-46dc-a618-0fbba9e0a158",
        "requestMsgId": "aa33d42e-c526-4117-88a7-aa3034d82886"
      },
      "outputs": [],
      "source": [
        "# NOTE: The acquisition value here is the log of the expected improvement.\n",
        "MC_LogEI_resample(new_point_torch_SGD), new_point_torch_SGD"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1668649208566,
        "executionStopTime": 1668649208574,
        "originalKey": "e263cfc7-47a0-4b81-ab33-3aa16320c87e",
        "requestMsgId": "3c654fc0-ce64-43c7-a8bf-42935257008a"
      },
      "outputs": [],
      "source": [
        "torch.linalg.norm(new_point_torch_SGD - new_point_analytic)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "python3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
